You've already forked immich
mirror of
https://github.com/immich-app/immich.git
synced 2025-06-15 03:30:33 +02:00
chore(ml): use strict mypy (#5001)
* improved typing * improved export typing * strict mypy & check export folder * formatting * add formatting checks for export folder * re-added init call
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from math import e
|
||||
from pathlib import Path
|
||||
|
||||
import open_clip
|
||||
@ -69,10 +70,12 @@ def export_image_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig,
|
||||
output_path = Path(output_path)
|
||||
|
||||
def encode_image(image: torch.Tensor) -> torch.Tensor:
|
||||
return model.encode_image(image, normalize=True)
|
||||
output = model.encode_image(image, normalize=True)
|
||||
assert isinstance(output, torch.Tensor)
|
||||
return output
|
||||
|
||||
args = (torch.randn(1, 3, model_cfg.image_size, model_cfg.image_size),)
|
||||
traced = torch.jit.trace(encode_image, args)
|
||||
traced = torch.jit.trace(encode_image, args) # type: ignore[no-untyped-call]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
@ -91,10 +94,12 @@ def export_text_encoder(model: open_clip.CLIP, model_cfg: OpenCLIPModelConfig, o
|
||||
output_path = Path(output_path)
|
||||
|
||||
def encode_text(text: torch.Tensor) -> torch.Tensor:
|
||||
return model.encode_text(text, normalize=True)
|
||||
output = model.encode_text(text, normalize=True)
|
||||
assert isinstance(output, torch.Tensor)
|
||||
return output
|
||||
|
||||
args = (torch.ones(1, model_cfg.sequence_length, dtype=torch.int32),)
|
||||
traced = torch.jit.trace(encode_text, args)
|
||||
traced = torch.jit.trace(encode_text, args) # type: ignore[no-untyped-call]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
|
Reference in New Issue
Block a user