from __future__ import annotations
from http import HTTPStatus
from typing import Any, Generator
from urllib.parse import urljoin
import httpx
from modelz.console import console
from modelz.serde import Serde, SerdeEnum, TextSerde
# Default HTTP timeouts (seconds): 5 s for every phase not listed explicitly
# (connect / pool) and 300 s for reading and writing request bodies.
TIMEOUT = httpx.Timeout(timeout=5, read=300, write=300)
# Serde used to decode a ModelzResponse when the caller does not pass one.
DEFAULT_RESP_SERDE: Serde = TextSerde()
# Number of retries handed to httpx.HTTPTransport by ModelzClient.
DEFAULT_RETRY: int = 3
class ModelzAuth(httpx.Auth):
    """httpx auth hook that attaches the Modelz API key to each request.

    Args:
        key: API key, sent in the ``X-API-Key`` request header.

    Raises:
        RuntimeError: if ``key`` is empty.
    """

    def __init__(self, key: str) -> None:
        # Fail fast: a client without a key cannot authenticate any call.
        if not key:
            raise RuntimeError("cannot find the API key")
        self.key: str = key

    def auth_flow(
        self, request: httpx.Request
    ) -> Generator[httpx.Request, httpx.Response, None]:
        """Stamp the API key onto ``request`` and hand it back to httpx once."""
        headers = request.headers
        headers["X-API-Key"] = self.key
        yield request
class ModelzResponse:
    """Modelz internal response.

    The initialization will raise an error if the response status code is not 200.

    Args:
        resp: the HTTP response returned by the Modelz service.
        serde: serde used to decode the response body (text by default).

    Raises:
        ValueError: if ``resp`` does not have status code 200.
    """

    def __init__(self, resp: httpx.Response, serde: Serde = DEFAULT_RESP_SERDE):
        if resp.status_code != HTTPStatus.OK:
            console.print(f"[bold red]err[{resp.status_code}][/bold red]: {resp.text}")
            raise ValueError(f"inference err with code {resp.status_code}")
        self.resp = resp
        self.serde = serde
        self._data = None
        # Separate flag so falsy decoded values ("", {}, [], 0, None) are
        # cached too instead of being decoded again on every access.
        self._decoded = False

    def save_to_file(self, file: str):
        """Save the response payload to ``file`` in binary format.

        Bytes-like decoded data is written as-is. Any other decoded value
        (e.g. ``str`` from the text serde or objects from the JSON serde) cannot
        be written to a binary file, so the raw response body is written
        instead.

        Args:
            file: destination path; it is created or truncated.
        """
        payload = self.data
        if not isinstance(payload, (bytes, bytearray, memoryview)):
            payload = self.resp.content
        with open(file, "wb") as f:
            f.write(payload)

    @property
    def data(self) -> Any:
        """Access the response data.

        The body is decoded with the serde given at construction on first
        access; the result is cached for later accesses.
        """
        if not self._decoded:
            self._data = self.serde.decode(self.resp.content)
            self._decoded = True
        return self._data

    def show(self):
        """Display the response data in the console with color."""
        console.print(self.data)
class ModelzClient:
    """Create a Modelz Client for standalone commands.

    Args:
        key: API key
        endpoint: endpoint URL; ``/inference`` (and, by default, ``/metrics``)
            are resolved against it
        timeout: request timeout (second)
        host: optional URL template with a ``{}`` placeholder; it is formatted
            with a deployment ID by :meth:`metrics` and with ``"api"`` by
            :meth:`build`
        deployment: optional default deployment ID used by :meth:`metrics`

    The client owns an ``httpx.Client`` connection pool; call :meth:`close`
    (or use the client as a context manager) to release it.
    """

    def __init__(
        self,
        key: str,
        endpoint: str | None = None,
        timeout: float | httpx.Timeout = TIMEOUT,
        *,
        host: str | None = None,
        deployment: str | None = None,
    ) -> None:
        self.endpoint = endpoint
        # Read by metrics()/build(); always defined so those methods raise a
        # clear ValueError instead of AttributeError when not configured.
        self.host = host
        self.deployment = deployment
        auth = ModelzAuth(key)
        transport = httpx.HTTPTransport(retries=DEFAULT_RETRY)
        self.client = httpx.Client(auth=auth, transport=transport)
        self.serde: Serde
        self.timeout = timeout

    def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        self.client.close()

    def __enter__(self) -> ModelzClient:
        return self

    def __exit__(self, *exc_info: Any) -> None:
        self.close()

    def inference(
        self,
        params: Any,
        serde: str = "json",
    ) -> ModelzResponse:
        """Get the inference result.

        Args:
            params: request params, will be serialized by ``serde``
            serde: serialize/deserialize method, choose from ("json", "msgpack", "raw")

        Raises:
            KeyError: if ``serde`` is not a known serde name.
            ValueError: if the client has no ``endpoint``, or the response
                status is not 200.
        """
        self.serde = SerdeEnum[serde.lower()].value()
        # urljoin(None, ...) would silently yield a relative URL.
        if not self.endpoint:
            raise ValueError("endpoint is required for inference")
        with console.status(f"[bold green]Modelz {self.endpoint} inference..."):
            resp = self.client.post(
                urljoin(self.endpoint, "/inference"),
                content=self.serde.encode(params),
                timeout=self.timeout,
            )
        return ModelzResponse(resp, self.serde)

    def metrics(self, deployment: str | None = None) -> ModelzResponse:
        """Get deployment metrics.

        Args:
            deployment: deployment ID; defaults to the client's ``deployment``.
                When no deployment ID is available, the metrics of the
                client's own ``endpoint`` are fetched.

        Raises:
            ValueError: if a deployment ID is given but no ``host`` template is
                configured, if neither a deployment nor an endpoint is set, or
                if the response status is not 200.
        """
        deploy = deployment if deployment else self.deployment
        if deploy:
            if not self.host:
                raise ValueError("host is required to resolve a deployment")
            base, label = self.host.format(deploy), deploy
        elif self.endpoint:
            base, label = self.endpoint, self.endpoint
        else:
            raise ValueError("deployment is required")
        with console.status(f"[bold green]Modelz {label} metrics..."):
            resp = self.client.get(
                urljoin(base, "/metrics"),
                timeout=self.timeout,
            )
        return ModelzResponse(resp)

    def build(self, repo: str):
        """Build a Docker image and push it to the registry.

        Args:
            repo: git repo url

        Raises:
            ValueError: if no ``host`` template is configured, or the response
                status is not 200.
        """
        if not self.host:
            raise ValueError("host is required to reach the build API")
        with console.status(f"[bold green]Modelz build {repo}..."):
            # NOTE(review): ``repo`` is not sent in the request at all; confirm
            # which payload the ``/build`` API expects.
            resp = self.client.post(
                urljoin(self.host.format("api"), "/build"),
                timeout=self.timeout,
            )
        ModelzResponse(resp)
        console.print(f"created the build job for repo [bold cyan]{repo}[/bold cyan]")