1
0
mirror of https://github.com/immich-app/immich.git synced 2025-06-15 03:30:33 +02:00

chore(ml): use strict mypy (#5001)

* improved typing

* improved export typing

* strict mypy & check export folder

* formatting

* add formatting checks for export folder

* re-added init call
This commit is contained in:
Mert
2023-11-13 11:18:46 -05:00
committed by GitHub
parent 9fa9ad05b1
commit 935f471ccb
10 changed files with 70 additions and 55 deletions

View File

@ -1,6 +1,7 @@
import tempfile
import warnings
from dataclasses import dataclass, field
from math import e
from pathlib import Path
import open_clip
@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
output_path = Path(output_path)
def encode_image(image: torch.Tensor) -> torch.Tensor:
return model.encode_image(image, normalize=True)
output = model.encode_image(image, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
traced = torch.jit.trace(encode_image, args)
traced = torch.jit.trace(encode_image, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
output_path = Path(output_path)
def encode_text(text: torch.Tensor) -> torch.Tensor:
return model.encode_text(text, normalize=True)
output = model.encode_text(text, normalize=True)
assert isinstance(output, torch.Tensor)
return output
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
traced = torch.jit.trace(encode_text, args)
traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call]
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)