mirror of https://github.com/httpie/cli.git synced 2025-08-10 22:42:05 +02:00

Prevent data race happening between select.select and file.read()

Batuhan Taskaya
2022-02-25 15:42:13 +03:00
parent 55087a901e
commit b0f5b8ab26


@@ -2,6 +2,8 @@ import sys
 import os
 import zlib
 import functools
+import time
+import threading
 
 from typing import Any, Callable, IO, Iterable, Optional, Tuple, Union, TYPE_CHECKING
 from urllib.parse import urlencode
@@ -22,12 +24,20 @@ class ChunkedStream:
 
 
 class ChunkedUploadStream(ChunkedStream):
-    def __init__(self, stream: Iterable, callback: Callable):
+    def __init__(
+        self,
+        stream: Iterable,
+        callback: Callable,
+        event: Optional[threading.Event] = None
+    ) -> None:
         self.callback = callback
         self.stream = stream
+        self.event = event
 
     def __iter__(self) -> Iterable[Union[str, bytes]]:
         for chunk in self.stream:
+            if self.event:
+                self.event.set()
             self.callback(chunk)
             yield chunk
 
@@ -35,12 +45,19 @@ class ChunkedUploadStream(ChunkedStream):
 class ChunkedMultipartUploadStream(ChunkedStream):
     chunk_size = 100 * 1024
 
-    def __init__(self, encoder: 'MultipartEncoder'):
+    def __init__(
+        self,
+        encoder: 'MultipartEncoder',
+        event: Optional[threading.Event] = None
+    ) -> None:
         self.encoder = encoder
+        self.event = event
 
     def __iter__(self) -> Iterable[Union[str, bytes]]:
         while True:
             chunk = self.encoder.read(self.chunk_size)
+            if self.event:
+                self.event.set()
             if not chunk:
                 break
             yield chunk
@@ -80,7 +97,7 @@ def is_stdin(file: IO) -> bool:
 READ_THRESHOLD = float(os.getenv("HTTPIE_STDIN_READ_WARN_THRESHOLD", 10.0))
 
 
-def observe_stdin_for_data_thread(env: Environment, file: IO) -> None:
+def observe_stdin_for_data_thread(env: Environment, file: IO, read_event: threading.Event) -> None:
     # Windows unfortunately does not support select() operation
     # on regular files, like stdin in our use case.
     # https://docs.python.org/3/library/select.html#select.select
@@ -92,12 +109,9 @@ def observe_stdin_for_data_thread(env: Environment, file: IO) -> None:
     if READ_THRESHOLD == 0:
         return None
 
-    import select
-    import threading
-
-    def worker():
-        can_read, _, _ = select.select([file], [], [], READ_THRESHOLD)
-        if not can_read:
+    def worker(event: threading.Event) -> None:
+        time.sleep(READ_THRESHOLD)
+        if not event.is_set():
             env.stderr.write(
                 f'> warning: no stdin data read in {READ_THRESHOLD}s '
                 f'(perhaps you want to --ignore-stdin)\n'
@@ -105,11 +119,28 @@ def observe_stdin_for_data_thread(env: Environment, file: IO) -> None:
             )
 
     thread = threading.Thread(
-        target=worker
+        target=worker,
+        args=(read_event,)
     )
     thread.start()
 
 
+def _read_file_with_selectors(file: IO, read_event: threading.Event) -> bytes:
+    if is_windows or not is_stdin(file):
+        return as_bytes(file.read())
+
+    import select
+
+    # Try checking whether there is any incoming data for READ_THRESHOLD
+    # seconds. If there isn't anything in the given period, issue
+    # a warning about a misusage.
+    read_selectors, _, _ = select.select([file], [], [], READ_THRESHOLD)
+    if read_selectors:
+        read_event.set()
+
+    return as_bytes(file.read())
+
+
 def _prepare_file_for_upload(
     env: Environment,
     file: Union[IO, 'MultipartEncoder'],
@@ -117,9 +148,11 @@ def _prepare_file_for_upload(
     chunked: bool = False,
     content_length_header_value: Optional[int] = None,
 ) -> Union[bytes, IO, ChunkedStream]:
+    read_event = threading.Event()
+
     if not super_len(file):
         if is_stdin(file):
-            observe_stdin_for_data_thread(env, file)
+            observe_stdin_for_data_thread(env, file, read_event)
         # Zero-length -> assume stdin.
         if content_length_header_value is None and not chunked:
             # Read the whole stdin to determine `Content-Length`.
@@ -129,7 +162,7 @@ def _prepare_file_for_upload(
             # something like --no-chunked.
             # This would be backwards-incompatible so wait until v3.0.0.
             #
-            file = as_bytes(file.read())
+            file = _read_file_with_selectors(file, read_event)
         else:
             file.read = _wrap_function_with_callback(
                 file.read,
@@ -141,11 +174,13 @@ def _prepare_file_for_upload(
         if isinstance(file, MultipartEncoder):
             return ChunkedMultipartUploadStream(
                 encoder=file,
+                event=read_event,
             )
         else:
             return ChunkedUploadStream(
                 stream=file,
                 callback=callback,
+                event=read_event
            )
     else:
         return file
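
For context on the fix: before this change, the warning thread called select.select() on stdin while the main upload path called file.read() on the same descriptor, so two threads competed for the same file. After it, only the reading path touches the descriptor (via _read_file_with_selectors) and signals a threading.Event, which the watcher merely sleeps on and checks before printing the warning. The sketch below is a minimal standalone illustration of that pattern, not HTTPie's actual code; READ_TIMEOUT, warn_if_no_data, and read_stdin are made-up names, and it assumes a Unix-like platform where select() accepts stdin.

import sys
import time
import select
import threading

READ_TIMEOUT = 10.0  # plays the role of HTTPIE_STDIN_READ_WARN_THRESHOLD


def warn_if_no_data(event: threading.Event) -> None:
    # No select() here: polling the descriptor from two threads is the
    # race the commit removes. Just wait, then check the shared flag.
    time.sleep(READ_TIMEOUT)
    if not event.is_set():
        sys.stderr.write(f'> warning: no stdin data read in {READ_TIMEOUT}s\n')


def read_stdin(event: threading.Event) -> bytes:
    # The reader is the only place that touches the descriptor: it waits
    # for readiness, sets the event so the watcher stays quiet, then reads.
    ready, _, _ = select.select([sys.stdin], [], [], READ_TIMEOUT)
    if ready:
        event.set()
    return sys.stdin.buffer.read()


if __name__ == '__main__':
    read_event = threading.Event()
    threading.Thread(target=warn_if_no_data, args=(read_event,), daemon=True).start()
    print(f'read {len(read_stdin(read_event))} bytes', file=sys.stderr)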