Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 132 additions & 33 deletions devbin/mujin_webstackclientpy_generategraphclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,34 @@ def _DereferenceType(graphType):
return graphType


def _CleanDocstring(docstring):
"""Clean up docstring formatting to match ruff standards."""
if not docstring:
return docstring
# split into lines and strip trailing whitespace
lines = [line.rstrip() for line in docstring.split('\n')]
# remove leading empty lines
while lines and not lines[0]:
lines.pop(0)
# remove trailing empty lines
while lines and not lines[-1]:
lines.pop()
# collapse multiple consecutive empty lines into single empty lines
resultLines = []
isPreviousEmpty = False
for line in lines:
if line:
resultLines.append(line)
isPreviousEmpty = False
elif not isPreviousEmpty:
resultLines.append('')
isPreviousEmpty = True
return '\n'.join(resultLines)


def _IndentNewlines(string, indent=' ' * 5):
"""Indent new lines in a string. Used for multi-line descriptions."""
return string.replace('\n', '\n' + indent)
return _CleanDocstring(string).replace('\n', '\n' + indent)


def _FormatTypeForDocstring(typeName):
Expand All @@ -64,13 +89,42 @@ def _FormatTypeForDocstring(typeName):
return _typeName


def _FormatTypeForAnnotation(typeName, isNullable=False):
"""Converts GraphQL types to Python type annotations."""
_typeName = str(typeName).replace('!', '')
if _typeName == 'String':
pythonType = 'str'
elif _typeName == 'Int':
pythonType = 'int'
elif _typeName == 'Boolean':
pythonType = 'bool'
elif _typeName == 'Float':
pythonType = 'float'
elif _typeName == 'Void':
# Void functions return None in Python
return 'None'
elif _typeName.startswith('[') and _typeName.endswith(']'):
# handle list types like [String!] -> List[str]
innerType = _typeName[1:-1].replace('!', '')
innerPythonType = _FormatTypeForAnnotation(innerType, False)
pythonType = 'List[%s]' % innerPythonType
else:
# for complex types, use Any
pythonType = 'Any'
# wrap in Optional if nullable
if isNullable:
pythonType = 'Optional[%s]' % pythonType
return pythonType


def _DiscoverType(graphType):
baseFieldType = _DereferenceType(graphType)
baseFieldTypeName = '%s' % baseFieldType
return {
'typeName': '%s' % graphType,
'baseTypeName': '%s' % baseFieldType,
'description': baseFieldType.description.strip(),
'isNullable': not isinstance(graphType, graphql.GraphQLNonNull),
}


Expand Down Expand Up @@ -113,28 +167,66 @@ def _PrintMethod(queryOrMutationOrSubscription, operationName, parameters, descr
builtinParameterNamesRequired = ('callbackFunction',)
builtinParameterNamesOptional = ('fields',)
builtinParameterNames = builtinParameterNamesRequired + builtinParameterNamesOptional
operationParametersRequired = []
operationParametersOptional = []

# build parameter list with type annotations
parameterList = []

# add builtin required parameters
for parameterName in builtinParameterNamesRequired:
if parameterName == 'callbackFunction':
parameterList.append('callbackFunction: Callable[[Optional[Any], Optional[Dict[str, Any]]], None]')
else:
parameterList.append(parameterName)

# add operation parameters (required and optional)
for parameter in parameters:
if parameter['parameterName'] in builtinParameterNames:
continue

if parameter['parameterDefaultValue'] is not None:
parameterType = parameter['parameterType']
if parameterType == 'String':
operationParametersOptional.append("%s='%s'" % (parameter['parameterName'], str(parameter['parameterDefaultValue'])))
# parameter has default value - don't wrap in Optional
parameterType = _FormatTypeForAnnotation(parameter['parameterType'], False)
if parameter['parameterType'] == 'String':
parameterList.append("%s: %s = '%s'" % (parameter['parameterName'], parameterType, str(parameter['parameterDefaultValue'])))
else:
operationParametersOptional.append('%s=%s' % (parameter['parameterName'], str(parameter['parameterDefaultValue'])))
continue
if parameter['parameterNullable'] is True:
operationParametersOptional.append('%s=None' % parameter['parameterName'])
continue
operationParametersRequired.append('%s' % parameter['parameterName'])
parameterList.append('%s: %s = %s' % (parameter['parameterName'], parameterType, str(parameter['parameterDefaultValue'])))
elif parameter['parameterNullable'] is True:
# parameter is optional - wrap in Optional
parameterType = _FormatTypeForAnnotation(parameter['parameterType'], True)
parameterList.append('%s: %s = None' % (parameter['parameterName'], parameterType))
else:
# parameter is required - don't wrap in Optional
parameterType = _FormatTypeForAnnotation(parameter['parameterType'], False)
parameterList.append('%s: %s' % (parameter['parameterName'], parameterType))

# add builtin optional parameters
for parameterName in builtinParameterNamesOptional:
if parameterName == 'fields':
parameterList.append('fields: Optional[Union[List[str], Dict[str, Any]]] = None')
elif parameterName == 'timeout':
parameterList.append('timeout: Optional[float] = None')
else:
parameterList.append('%s: Optional[Any] = None' % parameterName)

fullParameterList = list(builtinParameterNamesRequired) + operationParametersRequired + operationParametersOptional + ['%s=None' % name for name in builtinParameterNamesOptional]
print(' def %s(self, %s):' % (operationName, ', '.join(fullParameterList)))
# determine return type
if queryOrMutationOrSubscription == 'subscription':
finalReturnType = 'Subscription'
else:
finalReturnType = _FormatTypeForAnnotation(returnType['typeName'], returnType['isNullable'])

# print method signature with type annotations
if parameterList:
print(' def %s(' % operationName)
print(' self,')
for param in parameterList[:-1]:
print(' %s,' % param)
print(' %s,' % parameterList[-1]) # last parameter gets trailing comma
print(' ) -> %s:' % finalReturnType)
else:
print(' def %s(self) -> %s:' % (operationName, finalReturnType))

if description:
print(' """%s' % description)
print(' """%s' % _CleanDocstring(description))
else:
print(' """')
print('')
Expand All @@ -155,7 +247,7 @@ def _PrintMethod(queryOrMutationOrSubscription, operationName, parameters, descr
isOptionalString = ', optional' if parameter['parameterNullable'] else ''
print(' %s (%s%s):' % (parameter['parameterName'], _FormatTypeForDocstring(parameter['parameterType']), isOptionalString), end='')
if parameter['parameterDescription']:
print(' %s' % _IndentNewlines(parameter['parameterDescription']))
print(' %s' % _IndentNewlines(_CleanDocstring(parameter['parameterDescription'])))
else:
print('')
print(' fields (list or dict, optional): Specifies a subset of fields to return.')
Expand All @@ -165,27 +257,33 @@ def _PrintMethod(queryOrMutationOrSubscription, operationName, parameters, descr
print(' Returns:')
print(' %s:' % (_FormatTypeForDocstring(returnType['typeName'])), end='')
if returnType['description']:
print(' %s' % _IndentNewlines(returnType['description']))
print(' %s' % _IndentNewlines(_CleanDocstring(returnType['description'])))
else:
print('')
print(' """')

if deprecationReason:
print(' warnings.warn(\'"%s" is deprecated. %s\', DeprecationWarning, stacklevel=2)' % (operationName, deprecationReason))

print(' parameterNameTypeValues = [')
for parameter in parameters:
if parameter['parameterName'] in builtinParameterNames:
continue
print(" ('%s', '%s', %s)," % (parameter['parameterName'], parameter['parameterType'], parameter['parameterName']))
print(' ]')
# check if there are any parameters to add
if any(param['parameterName'] not in builtinParameterNames for param in parameters):
print(' parameterNameTypeValues: List[Tuple[str, str, Any]] = [')
for parameter in parameters:
if parameter['parameterName'] in builtinParameterNames:
continue
print(" ('%s', '%s', %s)," % (parameter['parameterName'], parameter['parameterType'], parameter['parameterName']))
print(' ]')
else:
print(' parameterNameTypeValues: List[Tuple[str, str, Any]] = []')

if queryOrMutationOrSubscription in ('query', 'mutation'):
print(
" return self._CallSimpleGraphAPI('%s', operationName='%s', parameterNameTypeValues=parameterNameTypeValues, returnType='%s', fields=fields, timeout=timeout)" % (queryOrMutationOrSubscription, operationName, returnType['baseTypeName']),
)
elif queryOrMutationOrSubscription == 'subscription':
print(" return self._CallSubscribeGraphAPI(operationName='%s', parameterNameTypeValues=parameterNameTypeValues, returnType='%s', callbackFunction=callbackFunction, fields=fields)" % (operationName, returnType['baseTypeName']))
print(
" return self._CallSubscribeGraphAPI(operationName='%s', parameterNameTypeValues=parameterNameTypeValues, returnType='%s', callbackFunction=callbackFunction, fields=fields)" % (operationName, returnType['baseTypeName']),
)


def _PrintClient(serverVersion, queryMethods, mutationMethods, subscriptionMethods):
Expand All @@ -197,26 +295,25 @@ def _PrintClient(serverVersion, queryMethods, mutationMethods, subscriptionMetho
print('#')
print('')
print('import warnings')
print('from typing import Any, Dict, List, Optional, Union, Callable, Tuple')
print('')
print('from .webstackgraphclientutils import GraphClientBase')
print('from .webstackgraphclientutils import UseLazyGraphQuery')
print('from .controllerwebclientraw import Subscription')
print('')
print('class GraphQueries:')
print('')
print('class GraphQueries:')
for queryMethod in queryMethods:
_PrintMethod('query', **queryMethod)
print('')
print('')
print('class GraphMutations:')
print('')
for mutationMethod in mutationMethods:
_PrintMethod('mutation', **mutationMethod)
print('')
print('')
print('class GraphSubscriptions:')
print('')
print(' def Unsubscribe(self, subscription: Subscription):')
print(' def Unsubscribe(self, subscription: Subscription) -> None:')
print(' """')
print(' Cancel an actively running subscription instance.')
print('')
Expand All @@ -232,30 +329,32 @@ def _PrintClient(serverVersion, queryMethods, mutationMethods, subscriptionMetho
print('class GraphQueriesClient(GraphClientBase, GraphQueries):')
print(' pass')
print('')
print('')
print('class GraphMutationsClient(GraphClientBase, GraphMutations):')
print(' pass')
print('')
print('')
print('class GraphSubscriptionsClient(GraphClientBase, GraphSubscriptions):')
print(' pass')
print('')
print('class GraphClient(GraphClientBase, GraphQueries, GraphMutations, GraphSubscriptions):')
print('')
print('class GraphClient(GraphClientBase, GraphQueries, GraphMutations, GraphSubscriptions):')
print(' @property')
print(' def queries(self):')
print(' def queries(self) -> GraphQueriesClient:')
print(' return GraphQueriesClient(self._webclient)')
print('')
print(' @property')
print(' def mutations(self):')
print(' def mutations(self) -> GraphMutationsClient:')
print(' return GraphMutationsClient(self._webclient)')
print('')
print(' @property')
print(' def subscriptions(self):')
print(' def subscriptions(self) -> GraphSubscriptionsClient:')
print(' return GraphSubscriptionsClient(self._webclient)')
print('')
print('')
print('#')
print('# DO NOT EDIT, THIS FILE WAS AUTO-GENERATED, SEE HEADER')
print('#')
print('')


def _Main():
Expand Down
43 changes: 34 additions & 9 deletions python/mujinwebstackclient/controllerwebclientraw.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import websockets
from requests import auth as requests_auth
from requests import adapters as requests_adapters
from typing import Optional, Callable
from typing import Optional, Callable, Dict, Any, Union, List
from urllib.parse import urlparse

import websockets.asyncio
Expand Down Expand Up @@ -157,7 +157,7 @@ class ControllerWebClientRaw(object):
_subscriptionLock: threading.Lock # Lock protecting _webSocket and _subscriptions
_backgroundThread: BackgroundThread = None # The background thread to handle async operations

def __init__(self, baseurl, username, password, locale=None, author=None, userAgent=None, additionalHeaders=None, unixEndpoint=None):
def __init__(self, baseurl: str, username: str, password: str, locale: Optional[str] = None, author: Optional[str] = None, userAgent: Optional[str] = None, additionalHeaders: Optional[Dict[str, str]] = None, unixEndpoint: Optional[str] = None) -> None:
self._baseurl = baseurl
self._username = username
self._password = password
Expand Down Expand Up @@ -241,7 +241,14 @@ def SetUserAgent(self, userAgent=None):
else:
self._headers.pop('User-Agent', None)

def Request(self, method, path, timeout=5, headers=None, **kwargs):
def Request(
self,
method: str,
path: str,
timeout: float = 5,
headers: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> requests.Response:
if timeout < 1e-6:
raise WebstackClientError(_('Timeout value (%s sec) is too small') % timeout)

Expand All @@ -263,7 +270,19 @@ def Request(self, method, path, timeout=5, headers=None, **kwargs):
return response

# Python port of the javascript API Call function
def APICall(self, method, path='', params=None, fields=None, data=None, headers=None, expectedStatusCode=None, files=None, timeout=5, apiVersion='v1'):
def APICall(
self,
method: str,
path: str = '',
params: Optional[Dict[str, Any]] = None,
fields: Optional[Union[List[str], Dict[str, Any]]] = None,
data: Optional[Union[str, Dict[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None,
expectedStatusCode: Optional[int] = None,
files: Optional[Dict[str, Any]] = None,
timeout: float = 5,
apiVersion: str = 'v1',
) -> Any:
path = '/api/%s/%s' % (apiVersion, path.lstrip('/'))
if apiVersion == 'v1' and not path.endswith('/'):
path += '/'
Expand Down Expand Up @@ -302,7 +321,7 @@ def APICall(self, method, path='', params=None, fields=None, data=None, headers=

# Try to parse response
raw = response.content.decode('utf-8', 'replace').strip()
content = None
content: Optional[Dict[str, Any]] = None
if len(raw) > 0:
try:
content = json.loads(raw)
Expand Down Expand Up @@ -338,7 +357,13 @@ def APICall(self, method, path='', params=None, fields=None, data=None, headers=

return content

def CallGraphAPI(self, query, variables=None, headers=None, timeout=5.0):
def CallGraphAPI(
self,
query: str,
variables: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: float = 5.0,
) -> Dict[str, Any]:
# prepare the headers
if headers is None:
headers = {}
Expand Down Expand Up @@ -368,7 +393,7 @@ def CallGraphAPI(self, query, variables=None, headers=None, timeout=5.0):
raise ControllerGraphClientException(_('Unexpected server response %d: %s') % (statusCode, raw), statusCode=statusCode, response=response)

# decode the response content
content = None
content: Optional[Dict[str, Any]] = None
if len(raw) > 0:
try:
content = json.loads(raw)
Expand All @@ -377,8 +402,8 @@ def CallGraphAPI(self, query, variables=None, headers=None, timeout=5.0):

# raise any error returned
if content is not None and 'errors' in content and len(content['errors']) > 0:
message = content['errors'][0].get('message', raw)
errorCode = None
message: str = content['errors'][0].get('message', raw)
errorCode: Optional[str] = None
if 'extensions' in content['errors'][0]:
errorCode = content['errors'][0]['extensions'].get('errorCode', None)
raise ControllerGraphClientException(message, statusCode=statusCode, content=content, response=response, errorCode=errorCode)
Expand Down
Loading