diff --git a/httpie/core.py b/httpie/core.py index 2079a914..677f4217 100644 --- a/httpie/core.py +++ b/httpie/core.py @@ -68,41 +68,21 @@ def decode_args(args, stdin_encoding): ] -def main(args=sys.argv[1:], env=Environment(), error=None): - """Run the main program and write the output to ``env.stdout``. +def program(args, env, log_error): + """ + The main program without error handling - Return exit status code. + :param args: parsed args (argparse.Namespace) + :type env: Environment + :param log_error: error log function + :return: status code """ - args = decode_args(args, env.stdin_encoding) - plugin_manager.load_installed_plugins() - - from httpie.cli import parser - - if env.config.default_options: - args = env.config.default_options + args - - def _error(msg, *args, **kwargs): - msg = msg % args - level = kwargs.get('level', 'error') - env.stderr.write('\nhttp: %s: %s\n' % (level, msg)) - - if error is None: - error = _error - - debug = '--debug' in args - traceback = debug or '--traceback' in args exit_status = ExitStatus.OK - - if debug: - print_debug_info(env) - if args == ['--debug']: - return exit_status - downloader = None - try: - args = parser.parse_args(args=args, env=env) + show_traceback = args.debug or args.traceback + try: if args.download: args.follow = True # --download implies --follow. downloader = Downloader( @@ -126,10 +106,10 @@ def main(args=sys.argv[1:], env=Environment(), error=None): follow=args.follow ) if not env.stdout_isatty and exit_status != ExitStatus.OK: - error('HTTP %s %s', - response.raw.status, - response.raw.reason, - level='warning') + log_error( + 'HTTP %s %s', response.raw.status, response.raw.reason, + level='warning' + ) write_stream_kwargs = { 'stream': build_output_stream( @@ -148,7 +128,7 @@ def main(args=sys.argv[1:], env=Environment(), error=None): else: write_stream(**write_stream_kwargs) except IOError as e: - if not traceback and e.errno == errno.EPIPE: + if not show_traceback and e.errno == errno.EPIPE: # Ignore broken pipes unless --traceback. env.stderr.write('\n') else: @@ -165,40 +145,11 @@ def main(args=sys.argv[1:], env=Environment(), error=None): downloader.finish() if downloader.interrupted: exit_status = ExitStatus.ERROR - error('Incomplete download: size=%d; downloaded=%d' % ( + log_error('Incomplete download: size=%d; downloaded=%d' % ( downloader.status.total_size, downloader.status.downloaded )) - - except KeyboardInterrupt: - if traceback: - raise - env.stderr.write('\n') - exit_status = ExitStatus.ERROR - except SystemExit as e: - if e.code != ExitStatus.OK: - if traceback: - raise - env.stderr.write('\n') - exit_status = ExitStatus.ERROR - except requests.Timeout: - exit_status = ExitStatus.ERROR_TIMEOUT - error('Request timed out (%ss).', args.timeout) - except requests.TooManyRedirects: - exit_status = ExitStatus.ERROR_TOO_MANY_REDIRECTS - error('Too many redirects (--max-redirects=%s).', args.max_redirects) - except Exception as e: - # TODO: Better distinction between expected and unexpected errors. - if traceback: - raise - msg = str(e) - if hasattr(e, 'request'): - request = e.request - if hasattr(request, 'url'): - msg += ' while doing %s request to URL: %s' % ( - request.method, request.url) - error('%s: %s', type(e).__name__, msg) - exit_status = ExitStatus.ERROR + return exit_status finally: if downloader and not downloader.finished: @@ -208,4 +159,91 @@ def main(args=sys.argv[1:], env=Environment(), error=None): args.output_file_specified): args.output_file.close() + +def main(args=sys.argv[1:], env=Environment(), custom_log_error=None): + """ + The main function. + + Pre-process args, handle some special type of invocations, and run the main + program with error handling. + + Return exit status code. + + """ + args = decode_args(args, env.stdin_encoding) + plugin_manager.load_installed_plugins() + + def log_error(msg, *args, level='error'): + msg = msg % args + env.stderr.write('\nhttp: %s: %s\n' % (level, msg)) + + from httpie.cli import parser + + if env.config.default_options: + args = env.config.default_options + args + + if custom_log_error: + log_error = custom_log_error + + include_debug_info = '--debug' in args + include_traceback = include_debug_info or '--traceback' in args + + if include_debug_info: + print_debug_info(env) + if args == ['--debug']: + return ExitStatus.OK + + exit_status = ExitStatus.OK + + try: + parsed_args = parser.parse_args(args=args, env=env) + except KeyboardInterrupt: + env.stderr.write('\n') + if include_traceback: + raise + exit_status = ExitStatus.ERROR + except SystemExit as e: + if e.code != ExitStatus.OK: + env.stderr.write('\n') + if include_traceback: + raise + exit_status = ExitStatus.ERROR + else: + try: + exit_status = program( + args=parsed_args, + env=env, + log_error=log_error, + ) + except KeyboardInterrupt: + env.stderr.write('\n') + if include_traceback: + raise + exit_status = ExitStatus.ERROR + except SystemExit as e: + if e.code != ExitStatus.OK: + env.stderr.write('\n') + if include_traceback: + raise + exit_status = ExitStatus.ERROR + except requests.Timeout: + exit_status = ExitStatus.ERROR_TIMEOUT + log_error('Request timed out (%ss).', parsed_args.timeout) + except requests.TooManyRedirects: + exit_status = ExitStatus.ERROR_TOO_MANY_REDIRECTS + log_error('Too many redirects (--max-redirects=%s).', + parsed_args.max_redirects) + except Exception as e: + # TODO: Further distinction between expected and unexpected errors. + msg = str(e) + if hasattr(e, 'request'): + request = e.request + if hasattr(request, 'url'): + msg += ' while doing %s request to URL: %s' % ( + request.method, request.url) + log_error('%s: %s', type(e).__name__, msg) + if include_traceback: + raise + exit_status = ExitStatus.ERROR + return exit_status diff --git a/tests/test_errors.py b/tests/test_errors.py index 1d7cf24c..040e889c 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -3,6 +3,7 @@ from pytest import raises from requests import Request, Timeout from requests.exceptions import ConnectionError +from httpie import ExitStatus from httpie.core import main error_msg = None @@ -17,8 +18,8 @@ def test_error(get_response): exc = ConnectionError('Connection aborted') exc.request = Request(method='GET', url='http://www.google.com') get_response.side_effect = exc - ret = main(['--ignore-stdin', 'www.google.com'], error=error) - assert ret == 1 + ret = main(['--ignore-stdin', 'www.google.com'], custom_log_error=error) + assert ret == ExitStatus.ERROR assert error_msg == ( 'ConnectionError: ' 'Connection aborted while doing GET request to URL: ' @@ -43,6 +44,6 @@ def test_timeout(get_response): exc = Timeout('Request timed out') exc.request = Request(method='GET', url='http://www.google.com') get_response.side_effect = exc - ret = main(['--ignore-stdin', 'www.google.com'], error=error) - assert ret == 2 + ret = main(['--ignore-stdin', 'www.google.com'], custom_log_error=error) + assert ret == ExitStatus.ERROR_TIMEOUT assert error_msg == 'Request timed out (30s).'