Visualizing Trees with Chinese Labels in XGBoost and RandomForest

Table of Contents

XGBoost

import re
def set_graph_font(graph):
    """Inject a FangSong node-font declaration into ``graph.source``.

    The DOT text held in ``graph.source`` is searched for the
    ``graph [ rankdir=TB ]`` header that is followed by a blank line and
    the definition of node ``0``. A ``node [fontname="FangSong"
    shape=plaintext]`` statement is inserted between the two. When the
    header is not found the text is left untouched.

    Args:
        graph: Object with a readable and writable ``source`` attribute
            containing DOT text.

    Returns:
        The value of ``graph.source`` after the rewritten text has been
        stored back on ``graph``.
    """
    # Regex locating the graph header immediately before node 0.
    header_pattern = r'graph \[ rankdir=TB \]\n\n    0 '
    # Same header with the node-font declaration spliced in after it.
    header_with_font = r'graph [ rankdir=TB ]\n\n node [fontname="FangSong" shape=plaintext]\n\n    0 '

    patched_dot = re.sub(header_pattern, header_with_font, graph.source)
    graph.source = patched_dot
    return graph.source
# Build a graphviz object for one tree of the trained booster.
# NOTE(review): `xgb` and `model` are defined outside this snippet; `num_trees=9`
# presumably selects the tree with index 9 -- confirm against the xgboost docs.
diagraph = xgb.to_graphviz(model, num_trees=9)
# Request PNG as the output format of the graph object.
diagraph.format = 'png'
# Patch the DOT source in place so nodes use the FangSong font.
set_graph_font(diagraph)
# Bare expression: presumably the last line of a notebook cell, so the graph is displayed.
diagraph

RandomForest

def plot_forest(model, max_trees=10):
    """Export the first trees of a forest as PNG files using the FangSong font.

    For every one of the first ``max_trees`` entries of
    ``model.estimators_`` three files are written to the ``forest/``
    directory (created if it does not exist):

    * ``tree-{i}.dot``    -- DOT source written by ``export_graphviz``
    * ``tree-cn-{i}.dot`` -- the same source with ``helvetica`` replaced
      by ``FangSong``
    * ``tree-cn-{i}.png`` -- rendering produced by the Graphviz ``dot``
      command at 600 dpi

    Args:
        model: Object exposing an ``estimators_`` sequence whose items
            are accepted by ``export_graphviz``.
        max_trees: Maximum number of trees to export. Defaults to 10,
            the limit that was previously hard-coded.

    Note:
        Relies on ``export_graphviz``, ``call``, ``re`` and
        ``column_names`` being available at module level, and on the
        Graphviz ``dot`` executable being on ``PATH``.
    """
    import os

    # The writers below do not create missing directories themselves.
    os.makedirs('forest', exist_ok=True)

    for i, estimator in enumerate(model.estimators_[:max_trees]):
        dot_path = f'forest/tree-{i}.dot'
        cn_dot_path = f'forest/tree-cn-{i}.dot'
        cn_png_path = f'forest/tree-cn-{i}.png'

        # Export as dot file. The loop variable is passed here; the
        # original referenced an undefined name instead of the tree.
        export_graphviz(
            estimator,
            out_file=dot_path,
            feature_names=column_names,
            class_names=["Class Name 1", "Class Name 2"],
            rounded=True,
            proportion=False,
            precision=2,
            filled=True,
        )

        # Swap the font name in the DOT text; presumably needed so that
        # Chinese labels render with a font that contains the glyphs.
        with open(dot_path, "r", encoding="utf-8") as fd:
            source = fd.read()
        source = re.sub(r"helvetica", r"FangSong", source)
        with open(cn_dot_path, "w", encoding="utf-8") as fd:
            fd.write(source)

        # Convert to png using system command (requires Graphviz)
        call(['dot', '-Tpng', cn_dot_path, '-o', cn_png_path, '-Gdpi=600'])

        # To display in a Jupyter notebook:
        #     Image(filename=f'forest/tree-cn-{i}.png')

Leave a Reply

en_USEnglish