1
0
mirror of https://github.com/httpie/cli.git synced 2025-01-10 00:28:12 +02:00

Refactor client

This commit is contained in:
Jakub Roztocil 2019-09-01 11:38:14 +02:00
parent a34b3d9d87
commit 4dffac7a25
3 changed files with 53 additions and 31 deletions

View File

@ -169,6 +169,7 @@ content_processing = parser.add_argument_group(
content_processing.add_argument( content_processing.add_argument(
'--compress', '-x', '--compress', '-x',
action='count', action='count',
default=0,
help=""" help="""
Content compressed (encoded) with Deflate algorithm. Content compressed (encoded) with Deflate algorithm.
The Content-Encoding header is set to deflate. The Content-Encoding header is set to deflate.

View File

@ -17,18 +17,13 @@ from httpie.utils import repr_dict
try: try:
# https://urllib3.readthedocs.io/en/latest/security.html
# noinspection PyPackageRequirements # noinspection PyPackageRequirements
import urllib3 import urllib3
# <https://urllib3.readthedocs.io/en/latest/security.html>
urllib3.disable_warnings() urllib3.disable_warnings()
except (ImportError, AttributeError): except (ImportError, AttributeError):
# In some rare cases, the user may have an old version of the requests
# or urllib3, and there is no method called "disable_warnings." In these
# cases, we don't need to call the method.
# They may get some noisy output but execution shouldn't die. Move on.
pass pass
FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8' FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8'
JSON_CONTENT_TYPE = 'application/json' JSON_CONTENT_TYPE = 'application/json'
JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*' JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*'
@ -49,9 +44,16 @@ def max_headers(limit):
class HTTPieHTTPAdapter(HTTPAdapter): class HTTPieHTTPAdapter(HTTPAdapter):
def __init__(self, ssl_version=None, compress=0, **kwargs): def __init__(
self,
ssl_version=None,
compression_enabled=False,
compress_always=False,
**kwargs,
):
self._ssl_version = ssl_version self._ssl_version = ssl_version
self._compress = compress self._compression_enabled = compression_enabled
self._compress_always = compress_always
super().__init__(**kwargs) super().__init__(**kwargs)
def init_poolmanager(self, *args, **kwargs): def init_poolmanager(self, *args, **kwargs):
@ -59,34 +61,50 @@ class HTTPieHTTPAdapter(HTTPAdapter):
super().init_poolmanager(*args, **kwargs) super().init_poolmanager(*args, **kwargs)
def send(self, request: requests.PreparedRequest, **kwargs): def send(self, request: requests.PreparedRequest, **kwargs):
if self._compress and request.body: if request.body and self._compression_enabled:
self._compress_body(request, self._compress) self._compress_body(request, always=self._compress_always)
return super().send(request, **kwargs) return super().send(request, **kwargs)
@staticmethod @staticmethod
def _compress_body(request: requests.PreparedRequest, compress: int): def _compress_body(request: requests.PreparedRequest, always: bool):
deflater = zlib.compressobj() deflater = zlib.compressobj()
if isinstance(request.body, bytes): body_bytes = (
deflated_data = deflater.compress(request.body) request.body
else: if isinstance(request.body, bytes)
deflated_data = deflater.compress(request.body.encode()) else request.body.encode()
)
deflated_data = deflater.compress(body_bytes)
deflated_data += deflater.flush() deflated_data += deflater.flush()
if len(deflated_data) < len(request.body) or compress > 1: is_economical = len(deflated_data) < len(body_bytes)
if is_economical or always:
request.body = deflated_data request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate' request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data)) request.headers['Content-Length'] = str(len(deflated_data))
def get_requests_session(ssl_version: str, compress: int) -> requests.Session: def build_requests_session(
ssl_version: str,
compress_arg: int,
) -> requests.Session:
requests_session = requests.Session() requests_session = requests.Session()
adapter = HTTPieHTTPAdapter(ssl_version=ssl_version, compress=compress)
for prefix in ['http://', 'https://']:
requests_session.mount(prefix, adapter)
for cls in plugin_manager.get_transport_plugins(): # Install our adapter.
transport_plugin = cls() adapter = HTTPieHTTPAdapter(
requests_session.mount(prefix=transport_plugin.prefix, ssl_version=ssl_version,
adapter=transport_plugin.get_adapter()) compression_enabled=compress_arg > 0,
compress_always=compress_arg > 1,
)
requests_session.mount('http://', adapter)
requests_session.mount('https://', adapter)
# Install adapters from plugins.
for plugin_cls in plugin_manager.get_transport_plugins():
transport_plugin = plugin_cls()
requests_session.mount(
prefix=transport_plugin.prefix,
adapter=transport_plugin.get_adapter(),
)
return requests_session return requests_session
@ -100,12 +118,15 @@ def get_response(
if args.ssl_version: if args.ssl_version:
ssl_version = SSL_VERSION_ARG_MAPPING[args.ssl_version] ssl_version = SSL_VERSION_ARG_MAPPING[args.ssl_version]
requests_session = get_requests_session(ssl_version, args.compress) requests_session = build_requests_session(
ssl_version=ssl_version,
compress_arg=args.compress
)
requests_session.max_redirects = args.max_redirects requests_session.max_redirects = args.max_redirects
with max_headers(args.max_headers): with max_headers(args.max_headers):
if not args.session and not args.session_read_only: if not args.session and not args.session_read_only:
kwargs = get_requests_kwargs(args) kwargs = make_requests_kwargs(args)
if args.debug: if args.debug:
dump_request(kwargs) dump_request(kwargs)
response = requests_session.request(**kwargs) response = requests_session.request(**kwargs)
@ -142,7 +163,7 @@ def finalize_headers(headers: RequestHeadersDict) -> RequestHeadersDict:
return final_headers return final_headers
def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict: def make_default_headers(args: argparse.Namespace) -> RequestHeadersDict:
default_headers = RequestHeadersDict({ default_headers = RequestHeadersDict({
'User-Agent': DEFAULT_UA 'User-Agent': DEFAULT_UA
}) })
@ -160,7 +181,7 @@ def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict:
return default_headers return default_headers
def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict: def make_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict:
""" """
Translate our `args` into `requests.request` keyword arguments. Translate our `args` into `requests.request` keyword arguments.
@ -177,7 +198,7 @@ def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict:
data = '' data = ''
# Finalize headers. # Finalize headers.
headers = get_default_headers(args) headers = make_default_headers(args)
if base_headers: if base_headers:
headers.update(base_headers) headers.update(base_headers)
headers.update(args.headers) headers.update(args.headers)

View File

@ -36,7 +36,7 @@ def get_response(
aspects of the session to the request. aspects of the session to the request.
""" """
from .client import get_requests_kwargs, dump_request from .client import make_requests_kwargs, dump_request
if os.path.sep in session_name: if os.path.sep in session_name:
path = os.path.expanduser(session_name) path = os.path.expanduser(session_name)
else: else:
@ -56,7 +56,7 @@ def get_response(
session = Session(path) session = Session(path)
session.load() session.load()
kwargs = get_requests_kwargs(args, base_headers=session.headers) kwargs = make_requests_kwargs(args, base_headers=session.headers)
if args.debug: if args.debug:
dump_request(kwargs) dump_request(kwargs)
session.update_headers(kwargs['headers']) session.update_headers(kwargs['headers'])