Table of Contents
XGBoost
import re
def set_graph_font(graph, fontname="FangSong"):
    """Inject a default node font into an XGBoost tree graph's DOT source.

    The DOT text on ``graph.source`` is patched by inserting a
    ``node [fontname="<fontname>" shape=plaintext]`` statement between the
    exact text ``graph [ rankdir=TB ]`` + two newlines and the `` 0 `` node
    line. Presumably this is so Chinese labels render. If that exact text is
    not present, the source is left unchanged.

    NOTE(review): assumes ``graph`` exposes a *writable* ``source`` string
    attribute (as the object returned by ``xgb.to_graphviz`` is used in this
    notebook). Confirm this for the installed graphviz/xgboost versions.

    Args:
        graph: Object with a ``source`` DOT string; it is mutated in place.
        fontname: Graphviz font applied to every node (default ``"FangSong"``).

    Returns:
        The patched DOT source string, which is also stored on ``graph.source``.
    """
    anchor = "graph [ rankdir=TB ]\n\n 0 "
    patched = (
        "graph [ rankdir=TB ]\n\n"
        f' node [fontname="{fontname}" shape=plaintext]\n\n 0 '
    )
    # Plain literal replacement (not re.sub). A font name containing regex
    # replacement escapes such as a backslash is then inserted verbatim.
    graph.source = graph.source.replace(anchor, patched)
    return graph.source
# Build a Graphviz view of the XGBoost tree with ordinal 9 (num_trees selects
# which tree to draw). `xgb` and `model` are defined in earlier cells.
diagraph = xgb.to_graphviz(model, num_trees=9)
# Render as PNG when displayed/rendered.
diagraph.format = 'png'
# Patch the DOT source in place to set the node font (FangSong); the returned
# source string is not needed here.
set_graph_font(diagraph)
# Bare expression as the cell's last line, presumably so Jupyter displays the graph.
diagraph
Random Forest
def plot_forest(model, n_trees=10, out_dir="forest", fontname="FangSong"):
    """Export the first trees of a fitted forest as DOT and PNG files.

    For each of the first ``n_trees`` entries of ``model.estimators_``:

    1. ``export_graphviz`` writes ``<out_dir>/tree-<i>.dot``.
    2. Every ``helvetica`` occurrence in that DOT text is replaced with
       ``fontname``, and the result is saved as ``<out_dir>/tree-cn-<i>.dot``.
    3. The Graphviz ``dot`` executable renders ``<out_dir>/tree-cn-<i>.png``
       at 600 dpi.

    Relies on names defined elsewhere in the notebook:

    - ``export_graphviz``, presumably ``sklearn.tree.export_graphviz``.
    - ``column_names``, used as the feature names.

    The class names are still the placeholders ``"Class Name 1"`` and
    ``"Class Name 2"``.

    Args:
        model: Fitted ensemble exposing an ``estimators_`` sequence
            (presumably a scikit-learn random forest; TODO confirm).
        n_trees: Maximum number of trees to export (default 10).
        out_dir: Output directory, created if missing (default ``"forest"``).
        fontname: Graphviz font substituted for ``helvetica``
            (default ``"FangSong"``).

    Raises:
        FileNotFoundError: If the ``dot`` executable cannot be found.
        subprocess.CalledProcessError: If ``dot`` exits with a non-zero status.
    """
    import os
    import subprocess

    # export_graphviz cannot write into a directory that does not exist.
    os.makedirs(out_dir, exist_ok=True)

    for i, estimator in enumerate(model.estimators_[:n_trees]):
        dot_path = os.path.join(out_dir, f"tree-{i}.dot")
        cn_dot_path = os.path.join(out_dir, f"tree-cn-{i}.dot")
        png_path = os.path.join(out_dir, f"tree-cn-{i}.png")

        # With out_file given, export_graphviz writes the file itself;
        # its return value is not needed.
        export_graphviz(estimator, out_file=dot_path,
                        feature_names=column_names,
                        class_names=["Class Name 1", "Class Name 2"],
                        rounded=True, proportion=False,
                        precision=2, filled=True)

        with open(dot_path, "r", encoding="utf-8") as fd:
            source = fd.read()
        # Swap the default font, presumably so Chinese labels render
        # (the output files carry a "-cn" suffix).
        source = source.replace("helvetica", fontname)
        with open(cn_dot_path, "w", encoding="utf-8") as fd:
            fd.write(source)

        # Render to PNG with the Graphviz CLI. Fail loudly instead of
        # silently ignoring a non-zero exit status.
        subprocess.run(
            ["dot", "-Tpng", cn_dot_path, "-o", png_path, "-Gdpi=600"],
            check=True,
        )
        # To display in Jupyter: IPython.display.Image(filename=png_path)