Project Icon

traceml

机器学习数据追踪与可视化工具,支持多种深度学习框架

TraceML 是一款强大的工具,用于机器学习和数据的追踪、可视化、解释和漂移检测。它与 Keras、PyTorch、TensorFlow、Fastai、Pytorch Lightning 和 HuggingFace 等多种深度学习和机器学习框架集成,方便用户记录和跟踪实验数据。TraceML 支持离线模式、多种数据可视化接口,并能生成详细的数据框架总结。

许可证: Apache 2 TraceML Slack 文档 GitHub GitHub

TraceML

Polyaxon 的 ML/数据跟踪、可视化、可解释性、漂移检测和仪表板引擎。

安装

pip install traceml

如果您想使用跟踪功能,还需要安装 polyaxon

pip install polyaxon traceml

[WIP] 本地沙盒

即将推出

离线使用

您可以启用离线模式以在没有 API 的情况下跟踪运行:

export POLYAXON_OFFLINE="true"

或者传递离线标志

from traceml import tracking

tracking.init(..., is_offline=True, ...)

在 Python 脚本中的简单使用

import random

import traceml as tracking

tracking.init(
    is_offline=True,
    project='quick-start',
    name="my-new-run",
    description="trying TraceML",
    tags=["examples"],
    artifacts_path="path/to/artifacts/repo"
)

# 跟踪一些数据引用
tracking.log_data_ref(content=X_train, name='x_train')
tracking.log_data_ref(content=y_train, name='y_train')

# 跟踪输入
tracking.log_inputs(
    batch_size=64,
    dropout=0.2,
    learning_rate=0.001,
    optimizer="Adam"
)

def get_loss(step):
    result = 10 / (step + 1)
    noise = (random.random() - 0.5) * 0.5 * result
    return result + noise

# 跟踪指标
for step in range(100):
    loss = get_loss(step)
    tracking.log_metrics(
    loss=loss,
    accuracy=(100 - loss) / 100.0,
)

# 跟踪一些一次性的结果
tracking.log_outputs(validation_score=0.66)

# 可选择手动停止跟踪过程
tracking.stop()

与深度学习和机器学习库和框架的集成

Keras

您可以使用 TraceML 的回调自动保存所有指标并收集输出和模型,还可以使用日志方法跟踪其他信息:

from traceml import tracking
from traceml.integrations.keras import Callback

tracking.init(
    is_offline=True,
    project='tracking-project',
    name="keras-run",
    description="trying TraceML & Keras",
    tags=["examples"],
    artifacts_path="path/to/artifacts/repo"
)

tracking.log_inputs(
    batch_size=64,
    dropout=0.2,
    learning_rate=0.001,
    optimizer="Adam"
)
tracking.log_data_ref(content=x_train, name='x_train')
tracking.log_data_ref(content=y_train, name='y_train')
tracking.log_data_ref(content=x_test, name='x_test')
tracking.log_data_ref(content=y_test, name='y_test')

# ...

model.fit(
    x_train,
    y_train,
    validation_data=(X_test, y_test),
    epochs=epochs,
    batch_size=100,
    callbacks=[Callback()],
)

PyTorch

您可以使用跟踪模块记录 Pytorch 实验的指标、输入和输出:

from traceml import tracking

tracking.init(
    is_offline=True,
    project='tracking-project',
    name="pytorch-run",
    description="trying TraceML & PyTorch",
    tags=["examples"],
    artifacts_path="path/to/artifacts/repo"
)

tracking.log_inputs(
    batch_size=64,
    dropout=0.2,
    learning_rate=0.001,
    optimizer="Adam"
)

# 指标
for batch_idx, (data, target) in enumerate(train_loader):
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    tracking.log_metrics(loss=loss)

asset_path = tracking.get_outputs_path('model.ckpt')
torch.save(model.state_dict(), asset_path)

# 记录模型
tracking.log_artifact_ref(asset_path, framework="pytorch", ...)

Tensorflow

您可以使用跟踪模块记录 Tensorflow 和分布式 Tensorflow 实验的指标、输出和模型:

from traceml import tracking
from traceml.integrations.tensorflow import Callback

tracking.init(
    is_offline=True,
    project='tracking-project',
    name="tf-run",
    description="trying TraceML & Tensorflow",
    tags=["examples"],
    artifacts_path="path/to/artifacts/repo"
)

tracking.log_inputs(
    batch_size=64,
    dropout=0.2,
    learning_rate=0.001,
    optimizer="Adam"
)

# 记录模型
estimator.train(hooks=[Callback(log_image=True, log_histo=True, log_tensor=True)])

Fastai

您可以使用跟踪模块记录 Fastai 实验的指标、输出和模型:

from traceml import tracking
from traceml.integrations.fastai import Callback

tracking.init(
    is_offline=True,
    project='tracking-project',
    name="fastai-run",
    description="trying TraceML & Fastai",
    tags=["examples"],
    artifacts_path="path/to/artifacts/repo"
)

# 记录模型指标
learn.fit(..., cbs=[Callback()])

Pytorch Lightning

您可以使用跟踪模块记录 Pytorch Lightning 实验的指标、输出和模型:

from traceml import tracking
from traceml.integrations.pytorch_lightning import Callback

tracking.init(
    is_offline=True,
    project='tracking-project',
    name="pytorch-lightning-run",
    description="trying TraceML & Lightning",
    tags=["examples"],
    artifacts_path="path/to/artifacts/repo"
)

...
trainer = pl.Trainer(
    gpus=0,
    progress_bar_refresh_rate=20,
    max_epochs=2,
    logger=Callback(),
)

HuggingFace

您可以使用跟踪模块记录 HuggingFace 实验的指标、输出和模型:

from traceml import tracking
from traceml.integrations.hugging_face import Callback

tracking.init(
    is_offline=True,
    project='tracking-project',
    name="hg-run",
    description="trying TraceML & HuggingFace",
    tags=["examples"],
    artifacts_path="path/to/artifacts/repo"
)

...
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    callbacks=[Callback],
    # ...
)

跟踪工件

import altair as alt
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
from bokeh.plotting import figure
from vega_datasets import data

from traceml import tracking


def plot_mpl_figure(step):
    np.random.seed(19680801)
    data = np.random.randn(2, 100)

    figure, axs = plt.subplots(2, 2, figsize=(5, 5))
    axs[0, 0].hist(data[0])
    axs[1, 0].scatter(data[0], data[1])
    axs[0, 1].plot(data[0], data[1])
    axs[1, 1].hist2d(data[0], data[1])

    tracking.log_mpl_image(figure, 'mpl_image', step=step)


def log_bokeh(step):
    factors = ["a", "b", "c", "d", "e", "f", "g", "h"]
    x = [50, 40, 65, 10, 25, 37, 80, 60]

    dot = figure(title="Categorical Dot Plot", tools="", toolbar_location=None,
                 y_range=factors, x_range=[0, 100])

    dot.segment(0, factors, x, factors, line_width=2, line_color="green", )
    dot.circle(x, factors, size=15, fill_color="orange", line_color="green", line_width=3, )

    factors = ["foo 123", "bar:0.2", "baz-10"]
    x = ["foo 123", "foo 123", "foo 123", "bar:0.2", "bar:0.2", "bar:0.2", "baz-10", "baz-10",
         "baz-10"]
    y = ["foo 123", "bar:0.2", "baz-10", "foo 123", "bar:0.2", "baz-10", "foo 123", "bar:0.2",
         "baz-10"]
    colors = [
        "#0B486B", "#79BD9A", "#CFF09E",
        "#79BD9A", "#0B486B", "#79BD9A",
        "#CFF09E", "#79BD9A", "#0B486B"
    ]

    hm = figure(title="Categorical Heatmap", tools="hover", toolbar_location=None,
                x_range=factors, y_range=factors)

    hm.rect(x, y, color=colors, width=1, height=1)

    tracking.log_bokeh_chart(name='confusion-bokeh', figure=hm, step=step)


def log_altair(step):
    source = data.cars()

    brush = alt.selection(type='interval')

    points = alt.Chart(source).mark_point().encode(
        x='Horsepower:Q',
        y='Miles_per_Gallon:Q',
        color=alt.condition(brush, 'Origin:N', alt.value('lightgray'))
    ).add_selection(
        brush
    )

    bars = alt.Chart(source).mark_bar().encode(
        y='Origin:N',
        color='Origin:N',
        x='count(Origin):Q'
    ).transform_filter(
        brush
    )

    chart = points & bars

    tracking.log_altair_chart(name='altair_chart', figure=chart, step=step)


def log_plotly(step):
    df = px.data.tips()

    fig = px.density_heatmap(df, x="total_bill", y="tip", facet_row="sex", facet_col="smoker")
    tracking.log_plotly_chart(name="2d-hist", figure=fig, step=step)


plot_mpl_figure(100)
log_bokeh(100)
log_altair(100)
log_plotly(100)

数据帧跟踪

摘要

一个扩展 pandas 数据帧描述功能的扩展。

该模块包含了 DataFrameSummary 对象,它扩展了 describe() 函数:

  • 属性
    • dfs.columns_stats:每列的计数、唯一值、缺失值、缺失百分比和类型
    • dsf.columns_types:列类型的计数
    • dfs[column]:更深入的列摘要
  • 功能
    • summary():扩展了 describe() 函数,包含 columns_stats 的值

DataFrameSummary 期望获得一个 pandas DataFrame 进行总结。

from traceml.summary.df import DataFrameSummary

dfs = DataFrameSummary(df)

获取列类型

dfs.columns_types


numeric     9
bool        3
categorical 2
unique      1
date        1
constant    1
dtype: int64

获取列统计

dfs.columns_stats


                      A            B        C              D              E
counts             5802         5794     5781           5781           4617
uniques            5802            3     5771            128            121
missing               0            8       21             21           1185
missing_perc         0%        0.14%    0.36%          0.36%         20.42%
types            unique  categorical  numeric        numeric        numeric

获取单个列摘要,例如数值列

# 我们也可以用数字 A[1] 访问列
dfs['A']

std                                                                 0.2827146
max                                                                  1.072792
min                                                                         0
variance                                                           0.07992753
mean                                                                0.5548516
5%                                                                  0.1603367
25%                                                                 0.3199776
50%                                                                 0.4968588
75%                                                                 0.8274732
95%                                                                  1.011255
iqr                                                                 0.5074956
kurtosis                                                            -1.208469
skewness                                                            0.2679559
sum                                                                  3207.597
mad                                                                 0.2459508
cv                                                                  0.5095319
zeros_num                                                                  11
zeros_perc                                                               0,1%
deviating_of_mean                                                          21
deviating_of_mean_perc                                                  0.36%
deviating_of_median                                                        21
deviating_of_median_perc                                                0.36%
top_correlations                         {u'D': 0.702240243124, u'E': -0.663}
counts                                                                   5781
uniques                                                                  5771
missing                                                                    21
missing_perc                                                            0.36%
types                                                                 numeric
Name: A, dtype: object

[进行中] 摘要

  • 添加列间的总结分析,例如 dfs[[1, 2]]

[进行中] 可视化

  • 添加用 matplotlib 的总结可视化。
  • 添加用 plotly 的总结可视化。
  • 添加用 altair 的总结可视化。
  • 添加预定义的分析报告。

[进行中] 目录和版本

  • 添加持久化摘要和链接到特定版本的可能性。
  • 集成质量库。
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

稿定AI

稿定设计 是一个多功能的在线设计和创意平台,提供广泛的设计工具和资源,以满足不同用户的需求。从专业的图形设计师到普通用户,无论是进行图片处理、智能抠图、H5页面制作还是视频剪辑,稿定设计都能提供简单、高效的解决方案。该平台以其用户友好的界面和强大的功能集合,帮助用户轻松实现创意设计。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号